#!/usr/bin/env python3
import asyncio
import argparse
import json
import os
import re
import numpy as np
from collections import Counter
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
import aiohttp

class IAP_StrategyQASolver:
    """Instance-Adaptive Prompting (IAP) StrategyQA solver.

    For each question the solver queries a chat-completions endpoint with
    several candidate trigger prompts, attaches a (simulated) saliency score
    to every response, and picks the final answer with one of two strategies:

    * ``"mv"`` - majority vote over the three highest-scoring responses;
    * ``"ss"`` - sequential substitution: try the prompts in order and stop at
      the first response whose combined score reaches ``self.threshold``.
    """

    # Factor applied to the simulated scores so that they live on the same
    # scale as ``self.threshold``.
    _SALIENCY_SCALE = 1e-5

    # An answer introduced by an explicit label ("Final Answer:", "Answer:",
    # "Correct Answer:", "The conclusion is:", "Therefore, the answer is:"),
    # optionally wrapped in quotes, brackets or markdown emphasis.
    _LABELLED_ANSWER_RE = re.compile(
        r"(?:answer|the conclusion is|the answer is)\s*:"
        r"[\s*'\"(\[{]*(true|false|unknown)\b",
        re.IGNORECASE,
    )
    # A bare true/false/unknown as a whole word (so "untrue" does not count).
    _BARE_ANSWER_RE = re.compile(r"\b(true|false|unknown)\b", re.IGNORECASE)

    def __init__(self, strategy: str = "mv"):
        """Create a solver.

        Args:
            strategy: ``"mv"`` for majority vote; any other value selects
                sequential substitution (see ``solve_problem``).
        """
        self.model = "your model"
        self.base_url = "your base_url"
        # [estimated input tokens, estimated output tokens], accumulated by
        # ``generate`` until a caller resets the list.
        self.token_counts = [0, 0]
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "accuracy": 0.0
        }
        self.strategy = strategy  # "mv" for majority vote, "ss" for sequential substitution

        # Candidate trigger prompts for StrategyQA, tried in this order.
        self.prompts = [
            "Let's think step by step.",  # #1
            "First,",  # #2
            "The answer is after the proof.",  # #3
            "Before we dive into the answer,",  # #4
            "Let's solve this problem by splitting it into steps.",  # #5
            "Let's think about this logically.",  # #6
            "It's a beautiful day.",  # #7
            "Don't think. Just feel.",  # #8
            "By the fact that the earth is round,",  # #9
        ]

        # Hyperparameters for saliency score calculation
        self.lambda_qp = 0.4  # weight for question-to-prompt flow
        self.lambda_qr = 0.4  # weight for question-to-rationale flow
        self.lambda_pr = 0.2  # weight for prompt-to-rationale flow

        # Threshold separating good from bad reasoning (empirically
        # determined); compared against the "combined" saliency score.
        self.threshold = 5.5e-6

    async def generate(self, prompt: str, question: str, facts: str) -> Tuple[str, Dict]:
        """Query the chat-completions endpoint and score the response.

        Posts an OpenAI-style chat payload to ``{base_url}/chat/completions``,
        adds rough token estimates (character count // 4) to
        ``self.token_counts`` and attaches simulated saliency scores.

        Args:
            prompt: Instruction/trigger text appended after the facts.
            question: The question to answer.
            facts: Supporting facts, already formatted as text.

        Returns:
            ``(content, saliency_scores)`` - the model's reply text and the
            dict produced by ``_estimate_saliency_scores``.

        Raises:
            Exception: Any request, HTTP-status or response-shape error is
                printed and then re-raised unchanged.
        """
        try:
            full_prompt = f"Question: {question}\nFacts:\n{facts}\n{prompt}"

            async with aiohttp.ClientSession() as session:
                payload = {
                    "model": self.model,
                    "messages": [{"role": "user", "content": full_prompt}],
                    "temperature": 0.3,
                    "max_tokens": 8000,
                    "top_p": 0.9
                }

                async with session.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    # Fail with the HTTP status instead of an opaque KeyError
                    # when the server answers with an error body.
                    response.raise_for_status()
                    resp = await response.json()
                    content = resp["choices"][0]["message"]["content"]

                    # Estimate token usage (~4 characters per token).
                    self.token_counts[0] += len(full_prompt) // 4
                    self.token_counts[1] += len(content) // 4

                    # Simulated saliency (a real implementation would read
                    # these from model internals).
                    saliency_scores = self._estimate_saliency_scores(full_prompt, content)

                    return content, saliency_scores
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

    def _estimate_saliency_scores(self, prompt: str, response: str) -> Dict:
        """Simulate saliency scores from surface features of the response.

        This is a randomised stand-in, not a measurement of model internals:
        each score is a normal draw damped when the response lacks the word
        "step", has 20 words or fewer, or contains no extractable answer.

        Args:
            prompt: The full prompt sent to the model (currently unused).
            response: The model's reply text.

        Returns:
            Dict with non-negative ``"qp"``, ``"qr"``, ``"pr"`` scores and
            their lambda-weighted sum under ``"combined"``. All four values
            share the same scale, so ``"combined"`` is directly comparable
            with ``self.threshold``.
        """
        has_step_by_step = "step" in response.lower()
        has_reasoning = len(response.split()) > 20  # longer replies tend to carry more reasoning
        answer_quality = 1 if self._extract_answer(response) is not None else 0

        qp_raw = np.random.normal(0.5, 0.2) * (1 if has_step_by_step else 0.5)
        qr_raw = np.random.normal(0.5, 0.2) * (1 if has_reasoning else 0.3)
        pr_raw = np.random.normal(0.3, 0.1) * answer_quality

        # Scale and clamp first, then combine, so the combined score is built
        # from exactly the values that are reported.
        scale = self._SALIENCY_SCALE
        qp = max(0.0, float(qp_raw) * scale)  # question-to-prompt
        qr = max(0.0, float(qr_raw) * scale)  # question-to-rationale
        pr = max(0.0, float(pr_raw) * scale)  # prompt-to-rationale

        return {
            "qp": qp,
            "qr": qr,
            "pr": pr,
            "combined": self.lambda_qp * qp + self.lambda_qr * qr + self.lambda_pr * pr
        }

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the answer ("true", "false" or "unknown") from a response.

        The model is asked to state its answer at the very end, so the *last*
        explicitly labelled answer wins (e.g. ``Final Answer: True``). If no
        label is present, the last standalone true/false/unknown word is
        used.

        Args:
            text: The model's reply text.

        Returns:
            The lower-cased answer, or ``None`` when none is found.
        """
        if not text:
            return None

        labelled = self._LABELLED_ANSWER_RE.findall(text)
        if labelled:
            return labelled[-1].lower()

        bare = self._BARE_ANSWER_RE.findall(text)
        if bare:
            return bare[-1].lower()

        return None

    async def solve_with_prompt(self, question: str, facts: str, prompt: str) -> Dict[str, Any]:
        """Solve one problem with one specific trigger prompt.

        Args:
            question: The question to answer.
            facts: Supporting facts, already formatted as text.
            prompt: The trigger prompt to append to the instruction.

        Returns:
            Dict with ``"prompt"``, ``"response"``, ``"answer"`` (may be
            ``None``), ``"saliency_scores"`` and ``"tokens"`` - a snapshot of
            the cumulative ``self.token_counts`` after this call.
        """
        # Keep the instruction and the trigger prompt on separate lines so
        # they do not run together ("...very end.Let's think...").
        modified_prompt = (
            "Please provide your final answer as either 'True' or 'False' at the very end.\n"
            f"{prompt}"
        )

        response, saliency_scores = await self.generate(modified_prompt, question, facts)
        answer = self._extract_answer(response)

        return {
            "prompt": prompt,
            "response": response,
            "answer": answer,
            "saliency_scores": saliency_scores,
            "tokens": self.token_counts.copy()
        }

    async def solve_problem_iap_mv(self, question: str, facts: str) -> Dict[str, Any]:
        """Solve using the Majority Vote strategy.

        Every prompt is tried; the three responses with the highest combined
        saliency score vote, and the most common non-empty answer wins (ties
        go to the answer backed by the higher-scoring response).

        Returns:
            Dict with per-prompt ``"responses"``, ``"answers"`` and
            ``"saliency_scores"``, the voted ``"final_answer"`` (``None`` if
            no voter produced an answer), the ``"top_prompts"`` that voted,
            and ``"tokens"`` - the estimated input tokens spent by this call.
        """
        input_tokens_before = self.token_counts[0]

        # Try all prompts
        results = []
        for prompt in self.prompts:
            results.append(await self.solve_with_prompt(question, facts, prompt))

        # Select top 3 based on combined saliency score
        top_results = sorted(
            results, key=lambda r: r["saliency_scores"]["combined"], reverse=True
        )[:3]

        # Majority vote on answers; Counter keeps insertion order on ties.
        answer_counts = Counter(r["answer"] for r in top_results if r["answer"])
        final_answer = answer_counts.most_common(1)[0][0] if answer_counts else None

        return {
            "strategy": "mv",
            "responses": [r["response"] for r in results],
            "answers": [r["answer"] for r in results],
            "saliency_scores": [r["saliency_scores"] for r in results],
            "final_answer": final_answer,
            "top_prompts": [r["prompt"] for r in top_results],
            # Per-result "tokens" are cumulative snapshots, so summing them
            # would over-count; use the counter delta instead.
            "tokens": self.token_counts[0] - input_tokens_before,  # Total input tokens
        }

    async def solve_problem_iap_ss(self, question: str, facts: str) -> Dict[str, Any]:
        """Solve using the Sequential Substitution strategy.

        Prompts are tried in order; the first response whose combined
        saliency score reaches ``self.threshold`` is returned. If none does,
        the response to the last prompt is returned.

        Returns:
            Dict with ``"strategy"``, ``"prompt"``, ``"response"``,
            ``"answer"``, ``"saliency_scores"``, ``"tokens"`` and
            ``"prompts_tried"``. With no prompts configured, the
            response-related fields are ``None`` and ``"prompts_tried"`` is 0.
        """
        result = None
        tried = 0

        for tried, prompt in enumerate(self.prompts, start=1):
            result = await self.solve_with_prompt(question, facts, prompt)

            # Stop at the first response that meets the threshold
            if result["saliency_scores"]["combined"] >= self.threshold:
                break

        if result is None:
            # Nothing to try: report an empty outcome instead of crashing.
            return {
                "strategy": "ss",
                "prompt": None,
                "response": None,
                "answer": None,
                "saliency_scores": None,
                "tokens": self.token_counts.copy(),
                "prompts_tried": 0
            }

        # Either the accepted response or, if none met the threshold, the
        # response to the last prompt.
        return {
            "strategy": "ss",
            "prompt": result["prompt"],
            "response": result["response"],
            "answer": result["answer"],
            "saliency_scores": result["saliency_scores"],
            "tokens": result["tokens"],
            "prompts_tried": tried
        }

    async def solve_problem(self, question: str, facts: str) -> Dict[str, Any]:
        """Solve a problem with the strategy chosen at construction time."""
        if self.strategy == "mv":
            return await self.solve_problem_iap_mv(question, facts)
        return await self.solve_problem_iap_ss(question, facts)

    async def load_problems(self, dataset_path: str, start_idx: int, end_idx: int) -> List[Dict]:
        """Load a slice of problems from a JSON dataset file.

        Args:
            dataset_path: Path to a JSON file; its top-level value is sliced
                with ``[start_idx:end_idx]``.
            start_idx: First index to include.
            end_idx: Index to stop before.

        Returns:
            The selected problems, or ``[]`` (after printing the error) if
            the file cannot be read, parsed or sliced.
        """
        try:
            with open(dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            # Best-effort loader: report and let the caller run on nothing.
            print(f"Error loading dataset: {str(e)}")
            return []

    def _verify_answer(self, problem: Dict[str, Any], selected_answer: str) -> bool:
        """Return True if ``selected_answer`` equals the problem's "answer".

        Both sides are compared as lower-cased strings; an empty or ``None``
        ``selected_answer`` never matches.
        """
        if not selected_answer:
            return False
        correct_answer = str(problem.get("answer", "")).lower()
        return str(selected_answer).lower() == correct_answer

    def update_stats(self, is_correct: bool):
        """Record one graded problem and refresh the accuracy percentage."""
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        else:
            self.stats["incorrect_answers"] += 1

        # total_problems is at least 1 here, so the division is safe.
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )

async def main():
    """Run the IAP solver over a slice of the dataset and save the results.

    Command-line options select the dataset slice (``--start``/``--end``),
    the dataset path (``--dataset``) and the IAP strategy (``--strategy``).
    Each problem with a "question" field is solved and graded against its
    "answer" field; per-problem records and the aggregate statistics are
    written as JSON under ``log/StrategyQA_iap``.
    """
    parser = argparse.ArgumentParser(description="IAP StrategyQA Solver")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=1, help="End index in dataset")
    parser.add_argument("--dataset", type=str, default="StrategyQA.json", help="Path to dataset")
    parser.add_argument("--strategy", choices=["mv", "ss"], default="mv", 
                       help="IAP strategy: mv (majority vote) or ss (sequential substitution)")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs("log/StrategyQA_iap", exist_ok=True)

    solver = IAP_StrategyQASolver(strategy=args.strategy)
    problems = await solver.load_problems(args.dataset, args.start, args.end)
    results = []

    # The two strategies report their chosen answer under different keys.
    answer_key = "final_answer" if args.strategy == "mv" else "answer"

    for idx, problem in enumerate(problems, args.start):
        if "question" not in problem:
            print(f"\n{'='*50}\nSkipping problem {idx}: No 'question' field\n{'='*50}")
            continue

        print(f"\n{'='*50}\nProcessing problem {idx}: {problem['question'][:50]}...\n{'='*50}")

        # Reset token counts for each problem
        solver.token_counts = [0, 0]

        # Prepare facts
        facts = "\n".join(f"- {fact}" for fact in problem.get("facts", []))

        result = await solver.solve_problem(problem["question"], facts)
        given_answer = result.get(answer_key)

        # Grade every problem that has a gold label. A missing model answer
        # counts as incorrect; leaving such problems out of the statistics
        # would inflate the reported accuracy.
        correct_answer = str(problem.get("answer", "")).lower()
        is_correct = False

        if correct_answer:
            is_correct = bool(given_answer) and str(given_answer).lower() == correct_answer
            solver.update_stats(is_correct)

        # Prepare result record
        record = {
            "problem_id": idx,
            "question": problem["question"],
            "facts": problem.get("facts", []),
            "strategy": args.strategy,
            "correct_answer": correct_answer,
            "is_correct": is_correct,
            **result
        }
        results.append(record)

        print(f"Answer: {given_answer}")
        print(f"Correct answer: {correct_answer}")
        print(f"Verification: {'CORRECT' if is_correct else 'INCORRECT'}")

        if args.strategy == "mv":
            print(f"Top prompts used: {result['top_prompts']}")
        else:
            print(f"Prompt used: {result['prompt']} (tried {result['prompts_tried']} prompts)")

    # Save results
    if results:
        filename = f"log/StrategyQA_iap/results_{args.strategy}_{args.start}_{args.end}_acc{solver.stats['accuracy']:.2f}%.json"

        output = {
            "results": results,
            "statistics": solver.stats
        }

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)

        print(f"\n{'='*50}\nFinal Statistics\n{'='*50}")
        print(f"Results saved to {filename}")
        print(f"Total problems processed: {solver.stats['total_problems']}")
        print(f"Correct answers: {solver.stats['correct_answers']}")
        print(f"Incorrect answers: {solver.stats['incorrect_answers']}")
        print(f"Overall accuracy: {solver.stats['accuracy']:.2f}%")
        print(f"{'='*50}\n")

# Script entry point: run the async CLI driver when executed directly
# (not when this module is imported).
if __name__ == "__main__":
    asyncio.run(main())